import matplotlib.pyplot as plt
import numpy as np
import bluepysnap

# Matplotlib color spec for each cell-type / target label used in the plots.
# The "mc2*" labels reuse colors already assigned above (black "k" and the
# Rt_RC blue).
# NOTE(review): "#ff7f02" is one digit off matplotlib's tab:orange ("#ff7f0e");
# confirm whether that is intentional.
colors_per_mtype = dict(
    [
        ("Rt_RC", "#1f77b4"),
        ("VPL_TC", "#ff7f02"),
        ("VPL_IN", "#2ca02c"),
        ("mc2_Column", "k"),
        ("mc2;VPL", "k"),
        ("mc2;Rt", "#1f77b4"),
    ]
)


def _plot_soma_trace(ax, sim, node, color, ts, te, highlight_starts, highlight_width):
    """Draw the soma-report trace of one ``thalamus_neurons`` node on *ax*.

    The trace is restricted to ``[ts, te]``. Grey bands of width
    *highlight_width* are drawn starting at each entry of *highlight_starts*.
    """
    filtered = sim.reports["soma_report"].filter(
        group={"node_id": node, "population": "thalamus_neurons"}, t_start=ts, t_stop=te
    )
    # Average across the report's columns (presumably one column per selected
    # node, time on the index). For a single node this is simply its trace.
    # Same result as the former ``report.T.mean()``.
    ax.plot(filtered.report.mean(axis=1), lw=0.8, c=color)
    ax.spines.top.set_visible(False)
    ax.spines.right.set_visible(False)
    for start in highlight_starts:
        ax.axvspan(start, start + highlight_width, color="grey", alpha=0.2)
    ax.set_xlim(ts, te)
    ax.set_ylim(-100, 30)


def plot_sample_traces(
    sim,
    fig_suffix,
    ts,
    te,
    *,
    rt_presyn_cells=(29305, 29332, 29336),
    tc_sample=slice(4, 7),
    highlight_starts=(3500, 4500),
    highlight_width=500,
):
    """Plot soma traces of sample RT cells and of their spiking postsynaptic targets.

    Writes one figure, with one row per plotted cell, to
    ``sample_traces_{fig_suffix}.pdf`` in the current working directory.

    Args:
        sim: Simulation object (a ``bluepysnap.Simulation`` in this script)
            exposing ``circuit``, ``spikes`` and a ``"soma_report"`` report.
        fig_suffix: Suffix inserted into the output PDF file name.
        ts: Start of the time window to filter and display.
        te: End of the time window to filter and display.
        rt_presyn_cells: ``thalamus_neurons`` node ids plotted as presynaptic
            RT cells. The defaults were picked by eye as good-looking traces
            in the variant1_Ecel_SUPER_LONG_100percent simulation.
        tc_sample: Slice selecting which of the sorted, spiking postsynaptic
            ids are plotted. The default picks 3 representative traces.
        highlight_starts: Start times of grey highlight bands on every row.
        highlight_width: Width of each highlight band.

    Raises:
        ValueError: If there is no cell to plot.
    """
    rt_presyn_cells = list(rt_presyn_cells)

    # Postsynaptic targets of the RT cells via thalamus->thalamus chemical edges.
    # NOTE(review): these are not filtered by mtype, so despite the "TC" naming
    # they may include non-TC targets. Confirm against the circuit.
    tc_postsyn_cells = (
        sim.circuit.edges["thalamus_neurons__thalamus_neurons__chemical"]
        .efferent_nodes(rt_presyn_cells)
        .tolist()
    )
    # Restrict spikes to ``thalamus_neurons``, matching the soma-report filters.
    # A bare id list could otherwise match same-numbered ids in other populations.
    spikes_df = sim.spikes.filter(
        group={"node_id": tc_postsyn_cells, "population": "thalamus_neurons"},
        t_start=ts,
        t_stop=te,
    ).report
    spiking_postsyn_tc = np.unique(spikes_df["ids"]).tolist()
    tc_postsyn_sample = spiking_postsyn_tc[tc_sample]

    # One row per cell: RT cells first, then the sampled postsynaptic cells.
    cells_and_colors = [(node, colors_per_mtype["Rt_RC"]) for node in rt_presyn_cells]
    cells_and_colors += [(node, colors_per_mtype["VPL_TC"]) for node in tc_postsyn_sample]
    if not cells_and_colors:
        raise ValueError("no cells to plot")

    # squeeze=False keeps ``axes`` a 2-D array even when there is a single row.
    fig, axes = plt.subplots(
        len(cells_and_colors), 1, figsize=(6, 5), sharex=True, sharey=True, squeeze=False
    )
    for ax, (node, color) in zip(axes[:, 0], cells_and_colors):
        _plot_soma_trace(ax, sim, node, color, ts, te, highlight_starts, highlight_width)

    fig.tight_layout()
    fig.savefig(f"sample_traces_{fig_suffix}.pdf")
    # Release the figure; otherwise pyplot keeps it alive for the whole session.
    plt.close(fig)


if __name__ == "__main__":
    # Simulation variant to load; it also becomes the suffix of the output PDF.
    circuit_variant = "variant1_Ecel_SUPER_LONG_100percent"
    # NOTE(review): this path is resolved relative to the current working
    # directory, not to this file. Run the script from its own directory.
    config_path = (
        "../../simulations/08_08_Cav_variants/variant1_Ecel/"
        f"{circuit_variant}_CT_up_down/simulation_config.json"
    )
    master_sim = bluepysnap.Simulation(config_path)

    plot_sample_traces(master_sim, circuit_variant, ts=2000, te=5500)
